import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

def Hmotional(thetas,J,K=None):
    '''
    Energy of the free motional model.

    Computes
        H = - Sum_{i,j} J_ij cos(theta_i + theta_j) - Sum_{i,j} K_ij sin(theta_i + theta_j),
    where both sums run over all N*N index pairs (i, j). The sine term is
    omitted when K is None.

    :param thetas: The N angles that specify the configuration.
    :type thetas: array-like of length N
    :param J: Pairwise coupling strengths of the cosine term.
    :type J: N-by-N symmetric matrix
    :param K: Optional pairwise coupling strengths of the sine term.
    :type K: N-by-N symmetric matrix or None

    :returns: scalar value of the motional Hamiltonian
    '''
    # theta_i + theta_j for every pair, shared by both terms
    pairSum = np.add.outer(thetas, thetas)
    energy = -np.einsum('ij,ij->', J, np.cos(pairSum))
    if K is not None:
        energy = energy - np.einsum('ij,ij->', K, np.sin(pairSum))
    return energy

def Hmotional_grad(thetas,J,K=None):
    '''
    Gradient of the motional Hamiltonian with respect to the angles.

    Component k is
        2 Sum_j J_kj sin(theta_k + theta_j) - 2 Sum_j K_kj cos(theta_k + theta_j),
    with the K term dropped when K is None. The factor 2 collects the (k, j)
    and (j, k) terms of the energy, so it assumes J and K are symmetric.

    :param thetas: The N angles that specify the configuration.
    :type thetas: array-like of length N
    :param J: Cosine coupling matrix.
    :type J: N-by-N symmetric matrix
    :param K: Optional sine coupling matrix.
    :type K: N-by-N symmetric matrix or None

    :returns: length-N array of partial derivatives dH/dtheta_k
    '''
    pairSum = np.add.outer(thetas, thetas)
    grad = 2*np.einsum('ij,ij->i', J, np.sin(pairSum))
    if K is not None:
        grad = grad - 2*np.einsum('ij,ij->i', K, np.cos(pairSum))
    return grad

def Hmotional_hess(thetas,J,K=None):
    '''
    Hessian of the motional Hamiltonian with respect to the angles.

    Define M_kl = 2 J_kl cos(theta_k + theta_l) + 2 K_kl sin(theta_k + theta_l),
    dropping the K term when K is None. Differentiating gradient component k,
    2 Sum_j J_kj sin(theta_k + theta_j) - 2 Sum_j K_kj cos(theta_k + theta_j),
    with respect to theta_l gives two contributions: M_kl from the j = l term,
    plus Sum_j M_kj when k = l. The result is M + diag(sums of M), which
    assumes J and K are symmetric.

    :param thetas: The N angles that specify the configuration.
    :type thetas: array-like of length N
    :param J: Cosine coupling matrix.
    :type J: N-by-N symmetric matrix
    :param K: Optional sine coupling matrix.
    :type K: N-by-N symmetric matrix or None

    :returns: N-by-N Hessian matrix
    '''
    pairSum = np.add.outer(thetas, thetas)
    M = 2*J*np.cos(pairSum)
    if K is not None:
        M += 2*K*np.sin(pairSum)
    diagonal = np.diag(M.sum(axis=0))
    return M + diagonal

def motionalMetroSinJ(theta,J,K,steps,Tfunc=-1,offset=0,stepSize=np.pi/8):
    '''
    Single-angle Metropolis updates for the motional Hamiltonian
        H = - Sum_{i,j} [ J_ij cos(theta_i + theta_j) + K_ij sin(theta_i + theta_j) ].

    Each step does the following:
      1. Pick a uniformly random index.
      2. Propose theta[idx] += N(0, stepSize^2).
      3. Accept the move with probability min(1, exp(-dE/T)).

    :param theta: Angles of the configuration. They are updated IN PLACE.
    :type theta: 1-D float numpy array of length N
    :param J: Cosine couplings. The factor 2 on the cross terms below is only the
        exact energy change when J is symmetric.
    :type J: N-by-N matrix
    :param K: Sine couplings, with the same symmetry assumption as J.
    :type K: N-by-N matrix
    :param steps: Number of single-angle proposals.
    :param Tfunc: Callable that maps (step index + offset) to a temperature.
        The default (-1) is the annealing schedule T(i) = 0.05 + 10*exp(-i/(10*N)).
    :param offset: Added to the step index before calling Tfunc, so a schedule
        can be continued across calls.
    :param stepSize: Standard deviation of the Gaussian proposal.

    :returns: theta, the same array object that was passed in.
    '''
    N = len(theta)
    # Zero the diagonals: the self-coupling (i == j) terms are handled separately below
    Joff = J - np.diag(np.diag(J))
    Koff = K - np.diag(np.diag(K))

    if Tfunc == -1:
        Tfunc = lambda i: 0.05 + 10*np.exp(-i/(10*N))

    for i in range(steps):

        idx = np.random.randint(N)
        dTheta = np.random.randn()*stepSize
        T = Tfunc(i+offset)

        # theta_j + theta_idx for every j, before the move
        phase = theta + theta[idx]
        twoTheta = 2*theta[idx]

        # Cross terms appear twice in H ((idx, j) and (j, idx)), hence the factor 2.
        # The diagonal term depends on 2*theta_idx, so it shifts by 2*dTheta.
        dE  = -2*np.dot(Joff[idx,:], np.cos(phase + dTheta) - np.cos(phase))
        dE += -J[idx,idx]*(np.cos(twoTheta + 2*dTheta) - np.cos(twoTheta))

        dE += -2*np.dot(Koff[idx,:], np.sin(phase + dTheta) - np.sin(phase))
        dE += -K[idx,idx]*(np.sin(twoTheta + 2*dTheta) - np.sin(twoTheta))

        # Clamping the exponent at 0 leaves every accept/reject decision unchanged,
        # because rand() < 1 always holds. It prevents np.exp overflow warnings for
        # strongly downhill moves at low T. np.minimum still propagates NaN.
        if np.random.rand() < np.exp(np.minimum(-dE/T, 0.0)):
            theta[idx] += dTheta

    return theta

def motionalEnergySin(theta,J,K):
    '''
    Energy of one configuration of the motional model.

    Evaluates
        - Sum_{i<=j} w_ij [ J_ij cos(theta_i + theta_j) + K_ij sin(theta_i + theta_j) ]
    with w_ii = 1 and w_ij = 2 for i < j. Only the upper triangle of J and K,
    including the diagonal, is read. For symmetric J and K this equals the full
    double sum that Hmotional evaluates.

    The computation is vectorised because it runs inside the parallel-tempering
    swap loop.

    :param theta: The N angles.
    :type theta: array-like of length N
    :param J: Cosine couplings.
    :type J: N-by-N matrix
    :param K: Sine couplings.
    :type K: N-by-N matrix

    :returns: scalar energy (0 for an empty configuration)
    '''
    theta = np.asarray(theta, dtype=float)
    phase = np.add.outer(theta, theta)
    terms = J*np.cos(phase) + K*np.sin(phase)
    # Strict upper triangle counted twice (pairs ij and ji), diagonal counted once
    return -(2*np.sum(np.triu(terms, k=1)) + np.trace(terms))

# Deterministic even/odd swap method.
# thetas should be of size (numTs,N)
# steps is actually the number of steps/2N
def motionalPTsin(J,K,Ts,steps,thetasA=-1,thetasB=-1,olaps=-1,mags=-1,stepSize=np.pi/8,nBins=50,square=True):
    N = J.shape[0]
    numTs = len(Ts)
    
    if thetasA==-1:
        thetasA = np.random.rand(numTs,N)*2*np.pi
        
    if thetasB==-1:
        thetasB = np.random.rand(numTs,N)*2*np.pi
    
    if olaps==-1:
        olaps = np.zeros((numTs,nBins,nBins),dtype=int)
        
    if mags==-1:
        mags = np.zeros((numTs,nBins,nBins),dtype=int) 
        
    val2idx = lambda val: int( np.round( ((val+1)/2)*(nBins-1) )) 
    
    for k in tqdm(range(steps)):
        
        # Even/odd swap counter
        for l in range(2):
        
            # For each temp, do metropolis and measurements    
            for i in range(numTs):
                
                # Do metropolis for each temperature
                Tfunc = lambda x: Ts[i]
                thetasA[i,:] = motionalMetroSinJ(thetasA[i,:],J,K,N,Tfunc=Tfunc,stepSize=stepSize)
                thetasB[i,:] = motionalMetroSinJ(thetasB[i,:],J,K,N,Tfunc=Tfunc,stepSize=stepSize)
                
                # Measurements
                mxA = np.sum(np.cos(thetasA[i,:]))/N
                myA = np.sum(np.sin(thetasA[i,:]))/N
                mxB = np.sum(np.cos(thetasB[i,:]))/N
                myB = np.sum(np.sin(thetasB[i,:]))/N
                qxx = np.dot( np.cos(thetasA[i,:]), np.cos(thetasB[i,:]) )/N
                qyy = np.dot( np.sin(thetasA[i,:]), np.sin(thetasB[i,:]) )/N
                
                # Store results in histograms
                mags [i,val2idx(mxA),val2idx(myA)] += 1
                mags [i,val2idx(mxB),val2idx(myB)] += 1
                
                if square:
                    olaps[i,val2idx(qxx+qyy),val2idx(qyy-qxx)] += 1
                else: 
                    olaps[i,val2idx(qxx),val2idx(qyy)] += 1

            # Even/odd swaps
            for i in range(l,numTs-1,2):
                
                E1A = motionalEnergySin(thetasA[i,:],J,K)
                E2A = motionalEnergySin(thetasA[i+1,:],J,K)
                
                if np.random.rand()< np.exp( (E2A-E1A)*(1/Ts[i+1]-1/Ts[i]) ):
                    tmp = thetasA[i,:]
                    thetasA[i,:] = thetasA[i+1,:]
                    thetasA[i+1,:] = tmp
                    
                E1B = motionalEnergySin(thetasB[i,:],J,K)
                E2B = motionalEnergySin(thetasB[i+1,:],J,K)
                
                if np.random.rand()< np.exp( (E2B-E1B)*(1/Ts[i+1]-1/Ts[i]) ):
                    tmp = thetasB[i,:]
                    thetasB[i,:] = thetasB[i+1,:]
                    thetasB[i+1,:] = tmp
    
    return thetasA,thetasB,olaps,mags

# Build an ensemble of independently annealed configurations. Both coupling
# matrices are normalised by the largest |eigenvalue| of J. Supply either the
# final temperature T0 or a full annealing schedule Tfunc. 'steps' is the number
# of Metropolis proposals per state divided by N.
def makeMotionalEnsemble(J,K,numStates=50,T0=0.05,Tfunc=-1,steps=400,**kwargs):
    '''
    Anneal ``numStates`` independent random starts with motionalMetroSinJ.

    :param J: Cosine coupling matrix (N-by-N, symmetric; it is passed to eigvalsh).
    :param K: Sine coupling matrix (N-by-N). It is scaled by the same factor as J.
    :param numStates: Number of independent configurations to generate.
    :param T0: Final temperature of the default schedule
        T(i) = T0 + 5*exp(-i/(10*N)). Ignored when Tfunc is given.
    :param Tfunc: Callable mapping a step index to a temperature. -1 selects the
        default schedule.
    :param steps: Metropolis proposals per state, in units of N.
    :param kwargs: Forwarded to motionalMetroSinJ (e.g. offset, stepSize).

    :returns: (numStates, N) array of annealed angles.
    :raises ValueError: If every eigenvalue of J is zero, so the normalisation
        is undefined.
    '''
    N = J.shape[0]

    # Normalise both coupling matrices by the largest |eigenvalue| of J
    emax = np.max(np.abs(np.linalg.eigvalsh(J)))
    if emax == 0:
        raise ValueError("cannot normalize couplings: all eigenvalues of J are zero")
    Jnorm = J/emax
    Knorm = K/emax

    if Tfunc==-1:
        # Default schedule: starts at T0 + 5 and decays towards T0
        Tfunc = lambda i: T0 + 5*np.exp(-i/(10*N))

    allThetas = np.zeros((numStates,N))

    # Each state starts from uniform random angles and is annealed independently
    for i in tqdm(range(numStates)):
        theta = np.random.rand(N)*2*np.pi
        theta = motionalMetroSinJ(theta,Jnorm,Knorm,steps*N,Tfunc=Tfunc,**kwargs)
        allThetas[i,:] = theta

    return allThetas

# Code Example
# NOTE(review): this example runs at import time (2000 parallel-tempering sweeps
# plus plotting). Consider guarding it with `if __name__ == "__main__":`.

# Seed before drawing the instance, so both the instance and the run are reproducible
np.random.seed(0)

# generate some instance, normally this is loaded from a file
J = np.random.normal(size=(8,8))
J = 0.5 * (J + J.T)

K = np.random.normal(scale=0.5,size=(8,8))
K = 0.5 * (K + K.T)

# Normalize by the largest |eigenvalue| of J; K is scaled by the same factor
eMax = np.amax(np.abs(np.linalg.eigvalsh(J)))
J = J/eMax
K = K/eMax

# some settings
numTs = 10
Ts = np.geomspace(0.01,2,numTs)
steps = 2000
nBins = 80

# run it
thetasA,thetasB,olaps,mags = motionalPTsin(J,K,Ts,steps,stepSize=np.pi/8,nBins=nBins,square=True)

# make a plot: one overlap histogram per temperature
ylimPT = 50  # colour-scale ceiling (vmax) shared by all panels

nCols = 5
nRows = int(np.ceil(numTs/nCols))
fig,axs = plt.subplots(nRows,nCols,figsize=(2*nCols,2*nRows),sharex=True,sharey=True,squeeze=False)
axs = axs.flatten()
for ax in axs[numTs:]:
    ax.set_visible(False)  # hide panels beyond the number of temperatures

for i in range(numTs):
    o = olaps[i,:,:]
    # Symmetrize under q -> -q. Shifting every angle of one replica by pi leaves
    # its energy unchanged (cos/sin of theta_i+theta_j gain 2*pi) but negates
    # both qxx and qyy.
    o = o + np.flipud(np.fliplr(o))

    ax = axs[i]
    ax.imshow(o.T,vmax=ylimPT,extent=[-1,1,-1,1],origin="lower")
    ax.set_aspect(1)
    ax.set_xticks(np.linspace(-1,1,3))
    ax.set_yticks(np.linspace(-1,1,3))
    ax.set_title("T="+str(np.round(Ts[i],3)))